//===--- PILFunction.cpp - Defines the PILFunction data structure ---------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#include "polarphp/pil/lang/PILArgument.h"
#include "polarphp/pil/lang/PILBasicBlock.h"
#include "polarphp/pil/lang/PILFunction.h"
#include "polarphp/pil/lang/PILInstruction.h"
#include "polarphp/pil/lang/PILModule.h"
#include "polarphp/pil/lang/PILProfiler.h"
#include "polarphp/pil/lang/PILFunctionCFG.h"
#include "polarphp/pil/lang/PrettyStackTrace.h"
#include "polarphp/ast/Availability.h"
#include "polarphp/ast/GenericEnvironment.h"
#include "polarphp/ast/Module.h"
#include "polarphp/basic/OptimizationMode.h"
#include "polarphp/basic/Statistic.h"
#include "llvm/ADT/Optional.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/GraphWriter.h"
#include "clang/AST/Decl.h"

using namespace polar;
using namespace polar::lowering;

/// Construct a specialization attribute from its parts.
PILSpecializeAttr::PILSpecializeAttr(bool exported, SpecializationKind kind,
                                     GenericSignature specializedSig)
   : kind(kind),
     exported(exported),
     specializedSignature(specializedSig) {}

/// Allocate storage for a PILSpecializeAttr from \p M's allocator and
/// construct the attribute in place there.
PILSpecializeAttr *PILSpecializeAttr::create(PILModule &M,
                                             GenericSignature specializedSig,
                                             bool exported,
                                             SpecializationKind kind) {
   void *storage =
      M.allocate(sizeof(PILSpecializeAttr), alignof(PILSpecializeAttr));
   auto *attr = ::new (storage) PILSpecializeAttr(exported, kind, specializedSig);
   return attr;
}

/// Attach \p Attr to this function. The attribute is recorded only when the
/// lowered function type has an invocation generic signature; otherwise it
/// is silently ignored.
void PILFunction::addSpecializeAttr(PILSpecializeAttr *Attr) {
   // Without a generic signature there is nothing to specialize.
   if (!getLoweredFunctionType()->getInvocationGenericSignature())
      return;

   Attr->F = this;
   SpecializeAttrSet.push_back(Attr);
}

/// Create a PILFunction in module \p M. When \p name is non-empty, the new
/// function is also registered in the module's function table. The remaining
/// arguments are forwarded to the PILFunction constructor.
PILFunction *
PILFunction::create(PILModule &M, PILLinkage linkage, StringRef name,
                    CanPILFunctionType loweredType,
                    GenericEnvironment *genericEnv, Optional<PILLocation> loc,
                    IsBare_t isBarePILFunction, IsTransparent_t isTrans,
                    IsSerialized_t isSerialized, ProfileCounter entryCount,
                    IsDynamicallyReplaceable_t isDynamic,
                    IsExactSelfClass_t isExactSelfClass,
                    IsThunk_t isThunk,
                    SubclassScope classSubclassScope, Inline_t inlineStrategy,
                    EffectsKind E, PILFunction *insertBefore,
                    const PILDebugScope *debugScope) {
   // An empty name is tolerated for the sake of error recovery; such a
   // function is simply not entered into the function table.
   llvm::StringMapEntry<PILFunction *> *entry = nullptr;
   if (!name.empty()) {
      auto insertion = M.FunctionTable.insert(std::make_pair(name, nullptr));
      entry = &*insertion.first;
      PrettyStackTracePILFunction trace("creating", entry->getValue());
      assert(!entry->getValue() && "function already exists");
      // Point `name` at the table's own copy of the key.
      name = entry->getKey();
   }

   auto *fn = new (M) PILFunction(
      M, linkage, name, loweredType, genericEnv, loc, isBarePILFunction,
      isTrans, isSerialized, entryCount, isThunk, classSubclassScope,
      inlineStrategy, E, insertBefore, debugScope, isDynamic, isExactSelfClass);

   // Publish the new function under its name.
   if (entry != nullptr)
      entry->setValue(fn);
   return fn;
}

/// Construct a PILFunction and link it into \p Module's function list.
///
/// The function is inserted before \p InsertBefore when that is non-null, and
/// appended to the module's function list otherwise. Its name is removed from
/// the module's zombie list. The block list's parent is set to this function.
///
/// NOTE(review): the \p Loc parameter is not referenced anywhere in this
/// constructor — confirm whether that is intentional.
PILFunction::PILFunction(PILModule &Module, PILLinkage Linkage, StringRef Name,
                         CanPILFunctionType LoweredType,
                         GenericEnvironment *genericEnv,
                         Optional<PILLocation> Loc, IsBare_t isBarePILFunction,
                         IsTransparent_t isTrans, IsSerialized_t isSerialized,
                         ProfileCounter entryCount, IsThunk_t isThunk,
                         SubclassScope classSubclassScope,
                         Inline_t inlineStrategy, EffectsKind E,
                         PILFunction *InsertBefore,
                         const PILDebugScope *DebugScope,
                         IsDynamicallyReplaceable_t isDynamic,
                         IsExactSelfClass_t isExactSelfClass)
   : Module(Module), Name(Name), LoweredType(LoweredType),
     GenericEnv(genericEnv), SpecializationInfo(nullptr),
     EntryCount(entryCount),
     Availability(AvailabilityContext::alwaysAvailable()),
     Bare(isBarePILFunction), Transparent(isTrans),
     Serialized(isSerialized), Thunk(isThunk),
     ClassSubclassScope(unsigned(classSubclassScope)), GlobalInitFlag(false),
     InlineStrategy(inlineStrategy), Linkage(unsigned(Linkage)),
     HasCReferences(false), IsWeakImported(false),
     IsDynamicReplaceable(isDynamic),
     ExactSelfClass(isExactSelfClass),
     Inlined(false), Zombie(false), HasOwnership(true),
     WasDeserializedCanonical(false), IsWithoutActuallyEscapingThunk(false),
     OptMode(unsigned(OptimizationMode::NotSet)),
     EffectsKindAttr(unsigned(E)) {
   // A transparent function may not also be dynamically replaceable.
   assert(!Transparent || !IsDynamicReplaceable);
   validateSubclassScope(classSubclassScope, isThunk, nullptr);
   setDebugScope(DebugScope);

   // Link into the module's function list at the requested position.
   if (InsertBefore)
      Module.functions.insert(PILModule::iterator(InsertBefore), this);
   else
      Module.functions.push_back(this);

   // NOTE(review): presumably drops any previously zombified function that
   // had this name — confirm against PILModule::removeFromZombieList.
   Module.removeFromZombieList(Name);

   // Make this function the parent of its basic-block list. This lets basic
   // blocks be spliced efficiently between functions.
   BlockList.Parent = this;
}

/// Destroy the function: drop all instruction references, release the
/// function this one replaces (if any), and destroy and deallocate every
/// instruction in every block. Asserts that no function_ref still refers to
/// this function.
PILFunction::~PILFunction() {
   // If the function is recursive, a function_ref inst inside of the function
   // will give the function a non-zero ref count triggering the assertion. Thus
   // we drop all instruction references before we erase.
   // We also need to drop all references if instructions are allocated using
   // an allocator that may recycle freed memory.
   dropAllReferences();

   // Release the reference count this function holds on the function it
   // replaces.
   if (ReplacedFunction) {
      ReplacedFunction->decrementRefCount();
      ReplacedFunction = nullptr;
   }

   auto &M = getModule();
   for (auto &BB : *this) {
      for (auto I = BB.begin(), E = BB.end(); I != E;) {
         auto Inst = &*I;
         // Advance before destroying: the current instruction is freed below.
         ++I;
         PILInstruction::destroy(Inst);
         // TODO: It is only safe to directly deallocate an
         // instruction if this BB is being removed in scope
         // of destructing a PILFunction.
         M.deallocateInst(Inst);
      }
      // The instructions were already destroyed and deallocated above, so
      // (presumably) the list is reset here without touching its nodes again.
      BB.InstList.clearAndLeakNodesUnsafely();
   }

   assert(RefCount == 0 &&
          "Function cannot be deleted while function_ref's still exist");
}

/// Create and attach a profiler for \p Root. The function must not already
/// have one.
void PILFunction::createProfiler(AstNode Root, PILDeclRef forDecl,
                                 ForDefinition_t forDefinition) {
   assert(!Profiler && "Function already has a profiler");
   Profiler = PILProfiler::create(Module, forDefinition, Root, forDecl);
}

/// Returns true if this function has a Clang node and that node's body is
/// generated by Clang (as judged by PILDeclRef::isClangGenerated).
bool PILFunction::hasForeignBody() const {
   return hasClangNode() && PILDeclRef::isClangGenerated(getClangNode());
}

/// Assign consecutive numbers to every value in the function, in block order.
/// Within each block, the arguments are numbered first and then the
/// instructions. An instruction with results shares the number of its first
/// result; an instruction without results gets a number of its own.
void PILFunction::numberValues(
   llvm::DenseMap<const PILNode *, unsigned> &ValueToNumberMap) const {
   unsigned nextId = 0;
   for (auto &block : *this) {
      for (auto argIt = block.args_begin(), argEnd = block.args_end();
           argIt != argEnd; ++argIt)
         ValueToNumberMap[*argIt] = nextId++;

      for (auto &inst : block) {
         auto results = inst.getResults();
         // The instruction node takes the id that its first result (if any)
         // will also receive.
         ValueToNumberMap[&inst] = nextId;
         if (results.empty()) {
            ++nextId;
            continue;
         }
         for (auto result : results)
            ValueToNumberMap[result] = nextId++;
      }
   }
}


AstContext &PILFunction::getAstContext() const {
   return getModule().getAstContext();
}

OptimizationMode PILFunction::getEffectiveOptimizationMode() const {
   if (OptimizationMode(OptMode) != OptimizationMode::NotSet)
      return OptimizationMode(OptMode);

   return getModule().getOptions().OptMode;
}

bool PILFunction::shouldOptimize() const {
   return getEffectiveOptimizationMode() != OptimizationMode::NoOptimization;
}

/// Map the interface type \p type into this function's generic context.
Type PILFunction::mapTypeIntoContext(Type type) const {
   auto *env = getGenericEnvironment();
   return GenericEnvironment::mapTypeIntoContext(env, type);
}

/// Map the lowered interface type \p type into this function's generic
/// context. If the function has no generic environment, \p type is returned
/// unchanged.
PILType PILFunction::mapTypeIntoContext(PILType type) const {
   auto *genericEnv = getGenericEnvironment();
   if (!genericEnv)
      return type;
   return genericEnv->mapTypeIntoContext(getModule(), type);
}

/// Map the lowered type \p type into this generic environment's context by
/// substituting through the environment's canonical generic signature.
/// \p type must not already contain archetypes.
PILType GenericEnvironment::mapTypeIntoContext(PILModule &M,
                                               PILType type) const {
   assert(!type.hasArchetype());

   auto canonicalSig = getGenericSignature()->getCanonicalSignature();
   QueryInterfaceTypeSubstitutions substitutions(this);
   LookUpConformanceInSignature conformances(canonicalSig.getPointer());
   return type.subst(M, substitutions, conformances, canonicalSig);
}

/// Returns true if this function's lowered type is a no-return function type.
bool PILFunction::isNoReturnFunction() const {
   auto fnType = PILType::getPrimitiveObjectType(getLoweredFunctionType());
   return fnType.isNoReturnFunction(getModule());
}

/// Type lowering for \p subst at abstraction pattern \p orig, computed in this
/// function's type expansion context.
const TypeLowering &
PILFunction::getTypeLowering(AbstractionPattern orig, Type subst) {
   auto &typeConverter = getModule().Types;
   return typeConverter.getTypeLowering(orig, subst,
                                        TypeExpansionContext(*this));
}

/// Type lowering for the formal type \p t in this function's expansion
/// context.
const TypeLowering &PILFunction::getTypeLowering(Type t) const {
   auto &typeConverter = getModule().Types;
   return typeConverter.getTypeLowering(t, TypeExpansionContext(*this));
}

/// The lowered PIL type of \p subst at abstraction pattern \p orig.
PILType
PILFunction::getLoweredType(AbstractionPattern orig, Type subst) const {
   auto &typeConverter = getModule().Types;
   return typeConverter.getLoweredType(orig, subst,
                                       TypeExpansionContext(*this));
}

/// The lowered PIL type of the formal type \p t.
PILType PILFunction::getLoweredType(Type t) const {
   auto &typeConverter = getModule().Types;
   return typeConverter.getLoweredType(t, TypeExpansionContext(*this));
}

/// The lowered PIL type of \p t (presumably \p t must lower to a loadable
/// type — see TypeConverter::getLoweredLoadableType).
PILType PILFunction::getLoweredLoadableType(Type t) const {
   auto &module = getModule();
   auto &typeConverter = module.Types;
   return typeConverter.getLoweredLoadableType(t, TypeExpansionContext(*this),
                                               module);
}

/// Type lowering for an already-lowered PIL type, relative to this function.
const TypeLowering &PILFunction::getTypeLowering(PILType type) const {
   auto &typeConverter = getModule().Types;
   return typeConverter.getTypeLowering(type, *this);
}

/// Re-lower \p t in this function's context while keeping its category.
PILType PILFunction::getLoweredType(PILType t) const {
   const TypeLowering &lowering = getTypeLowering(t);
   return lowering.getLoweredType().getCategoryType(t.getCategory());
}

/// Returns true if \p type is ABI-accessible from this function's type
/// expansion context.
bool PILFunction::isTypeABIAccessible(PILType type) const {
   auto &module = getModule();
   return module.isTypeABIAccessible(type, TypeExpansionContext(*this));
}

/// Returns true if this function should be weakly imported.
///
/// Clang-imported functions defer to their Clang declaration. A native
/// function is weakly imported only if it is available externally and either
/// is always weak-imported or has an availability that does not contain the
/// deployment target.
bool PILFunction::isWeakImported() const {
   // Functions imported from Clang: the Clang declaration decides.
   if (ClangNodeOwner)
      return ClangNodeOwner->getClangDecl()->isWeakImported();

   // Only externally available native functions can be weakly imported.
   if (!isAvailableExternally())
      return false;

   return isAlwaysWeakImported() ||
          (!Availability.isAlwaysAvailable() &&
           !AvailabilityContext::forDeploymentTarget(getAstContext())
               .isContainedIn(Availability));
}

/// Create a new basic block in this function with no anchor block.
/// NOTE(review): placement is decided by PILBasicBlock's constructor —
/// presumably at the end of the function; confirm.
PILBasicBlock *PILFunction::createBasicBlock() {
   auto &module = getModule();
   return new (module) PILBasicBlock(this, /*anchor*/ nullptr, /*after*/ false);
}

/// Create a new basic block placed immediately after \p afterBB.
PILBasicBlock *PILFunction::createBasicBlockAfter(PILBasicBlock *afterBB) {
   assert(afterBB);
   auto &module = getModule();
   return new (module) PILBasicBlock(this, afterBB, /*after*/ true);
}

/// Create a new basic block placed immediately before \p beforeBB.
PILBasicBlock *PILFunction::createBasicBlockBefore(PILBasicBlock *beforeBB) {
   assert(beforeBB);
   auto &module = getModule();
   return new (module) PILBasicBlock(this, beforeBB, /*after*/ false);
}

//===----------------------------------------------------------------------===//
//                          View CFG Implementation
//===----------------------------------------------------------------------===//

#ifndef NDEBUG

/// Column width at which a printed CFG node line is considered "long" and the
/// -view-cfg-long-line-behavior policy is applied.
static llvm::cl::opt<unsigned>
   MaxColumns("view-cfg-max-columns", llvm::cl::init(80),
              llvm::cl::desc("Maximum width of a printed node"));

namespace {
/// How to handle a printed line that reaches -view-cfg-max-columns.
enum class LongLineBehavior { None, Truncate, Wrap };
} // end anonymous namespace

static llvm::cl::opt<LongLineBehavior>
   LLBehavior("view-cfg-long-line-behavior",
              llvm::cl::init(LongLineBehavior::Truncate),
              llvm::cl::desc("Behavior when line width is greater than the "
                             "value provided by -view-cfg-max-columns "
                             "option"),
              llvm::cl::values(
                 clEnumValN(LongLineBehavior::None, "none", "Print everything"),
                 clEnumValN(LongLineBehavior::Truncate, "truncate",
                            "Truncate long lines"),
                 clEnumValN(LongLineBehavior::Wrap, "wrap", "Wrap long lines")));

/// When set, the "//" comments in printed block contents (such as use-list
/// comments) are stripped from CFG node labels.
static llvm::cl::opt<bool>
   RemoveUseListComments("view-cfg-remove-use-list-comments",
                         llvm::cl::init(false),
                         llvm::cl::desc("Should use list comments be removed"));

/// Return the case value of the switch-like instruction \p Inst whose
/// destination is \p BB. The first matching case wins. \p BB must be the
/// destination of one of \p Inst's cases; otherwise this is unreachable.
template <typename InstTy, typename CaseValueTy>
inline CaseValueTy getCaseValueForBB(const InstTy *Inst,
                                     const PILBasicBlock *BB) {
   const unsigned numCases = Inst->getNumCases();
   for (unsigned caseIdx = 0; caseIdx < numCases; ++caseIdx) {
      auto caseAndDest = Inst->getCase(caseIdx);
      if (caseAndDest.second == BB)
         return caseAndDest.first;
   }
   llvm_unreachable("Error! should never pass in BB that is not a successor");
}

namespace llvm {
template <>
struct DOTGraphTraits<PILFunction *> : public DefaultDOTGraphTraits {

   DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}

   static std::string getGraphName(PILFunction *F) {
      return "CFG for '" + F->getName().str() + "' function";
   }

   static std::string getSimpleNodeLabel(PILBasicBlock *Node, PILFunction *F) {
      std::string OutStr;
      raw_string_ostream OSS(OutStr);
      const_cast<PILBasicBlock *>(Node)->printAsOperand(OSS, false);
      return OSS.str();
   }

   static std::string getCompleteNodeLabel(PILBasicBlock *Node, PILFunction *F) {
      std::string Str;
      raw_string_ostream OS(Str);

      OS << *Node;
      std::string OutStr = OS.str();
      if (OutStr[0] == '\n')
         OutStr.erase(OutStr.begin());

      // Process string output to make it nicer...
      unsigned ColNum = 0;
      unsigned LastSpace = 0;
      for (unsigned i = 0; i != OutStr.length(); ++i) {
         if (OutStr[i] == '\n') { // Left justify
            OutStr[i] = '\\';
            OutStr.insert(OutStr.begin() + i + 1, 'l');
            ColNum = 0;
            LastSpace = 0;
         } else if (RemoveUseListComments && OutStr[i] == '/' &&
                    i != (OutStr.size() - 1) && OutStr[i + 1] == '/') {
            unsigned Idx = OutStr.find('\n', i + 1); // Find end of line
            OutStr.erase(OutStr.begin() + i, OutStr.begin() + Idx);
            --i;

         } else if (ColNum == MaxColumns) { // Handle long lines.

            if (LLBehavior == LongLineBehavior::Wrap) {
               if (!LastSpace)
                  LastSpace = i;
               OutStr.insert(LastSpace, "\\l...");
               ColNum = i - LastSpace;
               LastSpace = 0;
               i += 3; // The loop will advance 'i' again.
            } else if (LLBehavior == LongLineBehavior::Truncate) {
               unsigned Idx = OutStr.find('\n', i + 1); // Find end of line
               OutStr.erase(OutStr.begin() + i, OutStr.begin() + Idx);
               --i;
            }

            // Else keep trying to find a space.
         } else
            ++ColNum;
         if (OutStr[i] == ' ')
            LastSpace = i;
      }
      return OutStr;
   }

   std::string getNodeLabel(PILBasicBlock *Node, PILFunction *Graph) {
      if (isSimple())
         return getSimpleNodeLabel(Node, Graph);
      else
         return getCompleteNodeLabel(Node, Graph);
   }

   static std::string getEdgeSourceLabel(PILBasicBlock *Node,
                                         PILBasicBlock::succblock_iterator I) {
      const PILBasicBlock *Succ = *I;
      const TermInst *Term = Node->getTerminator();

      // Label source of conditional branches with "T" or "F"
      if (auto *CBI = dyn_cast<CondBranchInst>(Term))
         return (Succ == CBI->getTrueBB()) ? "T" : "F";

      // Label source of switch edges with the associated value.
      if (auto *SI = dyn_cast<SwitchValueInst>(Term)) {
         if (SI->hasDefault() && SI->getDefaultBB() == Succ)
            return "def";

         std::string Str;
         raw_string_ostream OS(Str);

         PILValue I = getCaseValueForBB<SwitchValueInst, PILValue>(SI, Succ);
         OS << I; // TODO: or should we output the literal value of I?
         return OS.str();
      }

      if (auto *SEIB = dyn_cast<SwitchEnumInst>(Term)) {
         std::string Str;
         raw_string_ostream OS(Str);

         EnumElementDecl *E =
            getCaseValueForBB<SwitchEnumInst, EnumElementDecl *>(SEIB, Succ);
         OS << E->getFullName();
         return OS.str();
      }

      if (auto *SEIB = dyn_cast<SwitchEnumAddrInst>(Term)) {
         std::string Str;
         raw_string_ostream OS(Str);

         EnumElementDecl *E =
            getCaseValueForBB<SwitchEnumAddrInst, EnumElementDecl *>(SEIB, Succ);
         OS << E->getFullName();
         return OS.str();
      }

      if (auto *DMBI = dyn_cast<DynamicMethodBranchInst>(Term))
         return (Succ == DMBI->getHasMethodBB()) ? "T" : "F";

      if (auto *CCBI = dyn_cast<CheckedCastBranchInst>(Term))
         return (Succ == CCBI->getSuccessBB()) ? "T" : "F";

      if (auto *CCBI = dyn_cast<CheckedCastValueBranchInst>(Term))
         return (Succ == CCBI->getSuccessBB()) ? "T" : "F";

      if (auto *CCBI = dyn_cast<CheckedCastAddrBranchInst>(Term))
         return (Succ == CCBI->getSuccessBB()) ? "T" : "F";

      return "";
   }
};
} // namespace llvm
#endif

#ifndef NDEBUG
/// When non-empty, -view-cfg output is restricted to the function with
/// exactly this name.
static llvm::cl::opt<std::string>
   TargetFunction("view-cfg-only-for-function", llvm::cl::init(""),
                  llvm::cl::desc("Only print out the cfg for this function"));
#endif

/// Open a graph viewer on \p f's CFG. With \p skipBBContents set, nodes use
/// short names instead of full block contents. This is a no-op when
/// assertions are disabled (NDEBUG).
static void viewCFGHelper(const PILFunction* f, bool skipBBContents) {
#ifndef NDEBUG
   const std::string fnName = f->getName().str();
   // Honor -view-cfg-only-for-function.
   if (!TargetFunction.empty() && fnName != TargetFunction)
      return;

   ViewGraph(const_cast<PILFunction *>(f), "cfg" + fnName,
             /*shortNames=*/skipBBContents);
#else
   (void)f;
   (void)skipBBContents;
#endif
}

/// Show this function's CFG with the full contents of every block.
void PILFunction::viewCFG() const {
   const bool skipBBContents = false;
   viewCFGHelper(this, skipBBContents);
}

/// Show this function's CFG with block names only.
void PILFunction::viewCFGOnly() const {
   const bool skipBBContents = true;
   viewCFGHelper(this, skipBBContents);
}


bool PILFunction::hasSelfMetadataParam() const {
   auto paramTypes = getConventions().getParameterPILTypes();
   if (paramTypes.empty())
      return false;

   auto silTy = *std::prev(paramTypes.end());
   if (!silTy.isObject())
      return false;

   auto selfTy = silTy.getAstType();

   if (auto metaTy = dyn_cast<MetatypeType>(selfTy)) {
      selfTy = metaTy.getInstanceType();
      if (auto dynamicSelfTy = dyn_cast<DynamicSelfType>(selfTy))
         selfTy = dynamicSelfTy.getSelfType();
   }

   return !!selfTy.getClassOrBoundGenericClass();
}

bool PILFunction::hasName(const char *Name) const {
   return getName() == Name;
}

/// Returns true if this function can be referenced from a fragile function
/// body.
bool PILFunction::hasValidLinkageForFragileRef() const {
   // Always allowed in three cases:
   // - 'static inline' functions imported from C;
   // - anything that may itself be inlined into fragile code;
   // - anything at all once the module is serialized, because invariants are
   //   no longer enforced at that point.
   if (hasForeignBody() || hasValidLinkageForFragileInline() ||
       getModule().isSerialized())
      return true;

   // A resilient subclass scope limits visibility outside the module despite
   // the function's linkage.
   const bool hiddenBySubclassScope =
      getClassSubclassScope() == SubclassScope::Resilient &&
      isAvailableExternally();
   if (hiddenBySubclassScope)
      return false;

   // Otherwise, only public functions can be referenced.
   return hasPublicVisibility(getLinkage());
}

/// Returns true if this function may be referenced from outside the current
/// compilation unit. This covers hidden functions referenced from C code and
/// functions that replace another function; all other cases are decided by
/// the function's linkage.
bool PILFunction::isPossiblyUsedExternally() const {
   const auto linkage = getLinkage();

   const bool hiddenButCalledFromC =
      linkage == PILLinkage::Hidden && hasCReferences();
   if (hiddenButCalledFromC || ReplacedFunction)
      return true;

   return polar::isPossiblyUsedExternally(linkage, getModule().isWholeModule());
}

/// Returns true if the function's effective symbol linkage may be used
/// outside the current compilation unit.
bool PILFunction::isExternallyUsedSymbol() const {
   const auto symbolLinkage = getEffectiveSymbolLinkage();
   const auto wholeModule = getModule().isWholeModule();
   return polar::isPossiblyUsedExternally(symbolLinkage, wholeModule);
}

/// Turn this definition into a declaration by dropping every reference held
/// by its body and then discarding all of its basic blocks.
void PILFunction::convertToDeclaration() {
   assert(isDefinition() && "Can only convert definitions to declarations");

   // Drop operand references first, then remove the blocks.
   dropAllReferences();
   getBlocks().clear();
}

/// Return the forwarding substitution map of this function's generic
/// environment. The map is cached in ForwardingSubMap on first computation.
/// If there is no generic environment, the (unset) cached map is returned.
SubstitutionMap PILFunction::getForwardingSubstitutionMap() {
   if (!ForwardingSubMap) {
      if (auto *env = getGenericEnvironment())
         ForwardingSubMap = env->getForwardingSubstitutionMap();
   }
   return ForwardingSubMap;
}

/// Ownership is verified unless the function carries the
/// "verify.ownership.sil.never" semantics attribute.
bool PILFunction::shouldVerifyOwnership() const {
   const bool optedOut = hasSemanticsAttr("verify.ownership.sil.never");
   return !optedOut;
}

// @todo
//static Identifier getIdentifierForObjCSelector(ObjCSelector selector, AstContext &Ctxt) {
//  SmallVector<char, 64> buffer;
//  auto str = selector.getString(buffer);
//  return Ctxt.getIdentifier(str);
//}

//void PILFunction::setObjCReplacement(AbstractFunctionDecl *replacedFunc) {
//  assert(ReplacedFunction == nullptr && ObjCReplacementFor.empty());
//  assert(replacedFunc != nullptr);
//  ObjCReplacementFor = getIdentifierForObjCSelector(
//      replacedFunc->getObjCSelector(), getAstContext());
//}

/// Record \p replacedFunc as the Objective-C replacement target of this
/// function. Asserts that neither a native replaced function nor an
/// Objective-C replacement has been recorded already.
void PILFunction::setObjCReplacement(Identifier replacedFunc) {
   assert(ReplacedFunction == nullptr && ObjCReplacementFor.empty());
   ObjCReplacementFor = replacedFunc;
}

// See polarphp/basic/Statistic.h for the declaration: this enables tracing of
// PILFunctions, and is defined here to avoid too much layering violation /
// circular linkage dependency.

struct PILFunctionTraceFormatter : public UnifiedStatsReporter::TraceFormatter {
   void traceName(const void *Entity, raw_ostream &OS) const {
      if (!Entity)
         return;
      const PILFunction *F = static_cast<const PILFunction *>(Entity);
      F->printName(OS);
   }

   void traceLoc(const void *Entity, SourceManager *SM,
                 clang::SourceManager *CSM, raw_ostream &OS) const {
      if (!Entity)
         return;
      const PILFunction *F = static_cast<const PILFunction *>(Entity);
      if (!F->hasLocation())
         return;
      F->getLocation().getSourceRange().print(OS, *SM, false);
   }
};

static PILFunctionTraceFormatter TF;

template<>
const UnifiedStatsReporter::TraceFormatter*
FrontendStatsTracer::getTraceFormatter<const PILFunction *>() {
   return &TF;
}
